"""An old implementation of AlexNet used for Places365."""
import numpy
from collections import OrderedDict

from torch import nn


class AlexNet(nn.Sequential):
    """AlexNet laid out as a flat ``nn.Sequential`` of named layers.

    Layers are registered, in order, as ``conv1 relu1 pool1 [lrn1] conv2
    relu2 pool2 [lrn2] conv3 relu3 conv4 relu4 conv5 relu5 pool5 flatten
    fc6 relu6 [dropout6] fc7 relu7 [dropout7] fc8``; bracketed entries are
    only present when the matching ``include_*`` flag is truthy.

    ``fc6`` takes ``256 * 6 * 6`` input features, so ``pool5`` must produce
    a 6x6 spatial map.
    """

    def __init__(self,
                 num_classes=None,
                 include_lrn=False,
                 split_groups=True,
                 include_dropout=False):
        """Build the layer sequence.

        Args:
            num_classes: Output width of ``fc8``. ``None`` keeps the default
                of 365.
            include_lrn: When truthy, keep the ``lrn1`` / ``lrn2`` layers.
            split_groups: When exactly ``True``, conv2, conv4 and conv5 use
                ``groups=2``; any other value gives ungrouped convolutions.
            include_dropout: When truthy, keep ``dropout6`` / ``dropout7``.
        """
        widths = [3, 96, 256, 384, 384, 256, 4096, 4096, 365]
        if num_classes is not None:
            widths[-1] = num_classes
        # Identity comparison on purpose: only the literal True enables
        # the grouped convolutions.
        groups = [1, 2, 1, 2, 2] if split_groups is True else [1, 1, 1, 1, 1]

        def conv(idx, **geometry):
            # Convolution number idx + 1, mapping widths[idx] -> widths[idx + 1].
            return nn.Conv2d(widths[idx],
                             widths[idx + 1],
                             groups=groups[idx],
                             bias=True,
                             **geometry)

        def relu():
            return nn.ReLU(inplace=True)

        def pool():
            return nn.MaxPool2d(kernel_size=3, stride=2)

        def lrn():
            return LRN(local_size=5, alpha=0.0001, beta=0.75)

        # Every candidate layer is instantiated here, in network order, so
        # parameter initialisation happens in the same order regardless of
        # which optional layers are later dropped.
        candidates = [
            ('conv1', conv(0, kernel_size=11, stride=4)),
            ('relu1', relu()),
            ('pool1', pool()),
            ('lrn1', lrn()),
            ('conv2', conv(1, kernel_size=5, padding=2)),
            ('relu2', relu()),
            ('pool2', pool()),
            ('lrn2', lrn()),
            ('conv3', conv(2, kernel_size=3, padding=1)),
            ('relu3', relu()),
            ('conv4', conv(3, kernel_size=3, padding=1)),
            ('relu4', relu()),
            ('conv5', conv(4, kernel_size=3, padding=1)),
            ('relu5', relu()),
            ('pool5', pool()),
            ('flatten', Vectorize()),
            ('fc6', nn.Linear(widths[5] * 6 * 6, widths[6], bias=True)),
            ('relu6', relu()),
            ('dropout6', nn.Dropout()),
            ('fc7', nn.Linear(widths[6], widths[7], bias=True)),
            ('relu7', relu()),
            ('dropout7', nn.Dropout()),
            ('fc8', nn.Linear(widths[7], widths[8])),
        ]

        # Name prefixes of the optional layers that were not requested.
        dropped_prefixes = []
        if not include_lrn:
            dropped_prefixes.append('lrn')
        if not include_dropout:
            dropped_prefixes.append('drop')
        dropped_prefixes = tuple(dropped_prefixes)

        # str.startswith(()) is always False, so an empty tuple keeps all.
        kept = OrderedDict(
            (layer_name, layer)
            for layer_name, layer in candidates
            if not layer_name.startswith(dropped_prefixes))
        super(AlexNet, self).__init__(kept)


class LRN(nn.Module):
    """Local response normalisation.

    Computes ``x / (1 + alpha * mean(x ** 2)) ** beta`` where the mean is
    taken by an average-pooling layer over a window of ``local_size``
    entries, either along the channel axis or over the spatial plane.
    """

    def __init__(self,
                 local_size=1,
                 alpha=1.0,
                 beta=0.75,
                 ACROSS_CHANNELS=True):
        """Configure the averaging window and normalisation constants.

        Args:
            local_size: Width of the averaging window.
            alpha: Multiplier applied to the windowed mean of squares.
            beta: Exponent applied to the normalisation term.
            ACROSS_CHANNELS: If truthy, average along the channel axis
                (via ``AvgPool3d``); otherwise average over a square
                spatial neighbourhood (via ``AvgPool2d``).
        """
        super(LRN, self).__init__()
        self.ACROSS_CHANNELS = ACROSS_CHANNELS
        self.alpha = alpha
        self.beta = beta

        # Padding on each side of the window; stride is 1 in both modes.
        half_window = int((local_size - 1.0) / 2)
        if ACROSS_CHANNELS:
            # The 3-d pool slides only along its first pooled axis, which
            # forward() arranges to be the channel axis.
            pooling = nn.AvgPool3d(kernel_size=(local_size, 1, 1),
                                   stride=1,
                                   padding=(half_window, 0, 0))
        else:
            pooling = nn.AvgPool2d(kernel_size=local_size,
                                   stride=1,
                                   padding=half_window)
        self.average = pooling

    def forward(self, x):
        """Return ``x`` divided by its local response normalisation term."""
        squared = x.pow(2)
        if self.ACROSS_CHANNELS:
            # Insert a dummy axis so the channel axis becomes a pooled
            # axis of AvgPool3d, then remove it again.
            mean_sq = self.average(squared.unsqueeze(1)).squeeze(1)
        else:
            mean_sq = self.average(squared)
        denominator = mean_sq.mul(self.alpha).add(1.0).pow(self.beta)
        return x.div(denominator)


class Vectorize(nn.Module):
    """Flatten every axis after the first into a single feature axis."""

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.size(0), product of other sizes)``.

        ``reshape`` is used instead of ``view`` so that inputs whose memory
        layout cannot be viewed as two-dimensional (for example permuted
        tensors) are copied rather than raising a RuntimeError. Inputs that
        can be viewed still share storage with the result.
        """
        # The feature count is computed explicitly rather than passed as -1
        # so that zero-batch inputs stay unambiguous and a 1-d input maps
        # to (N, 1) (the product over an empty shape is 1).
        features = int(numpy.prod(x.size()[1:]))
        return x.reshape(x.size(0), features)
